{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate 2D Images using SDS\n",
    "\n",
    "This notebook demonstrates how to generate 2D images using SDS. It is a good way to test the guidance techniques."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "import threestudio\n",
    "import gc\n",
    "import time\n",
    "import io\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import interact, IntSlider, Output\n",
    "from IPython.display import display, clear_output\n",
    "from PIL import Image\n",
    "\n",
    "def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles: float = 0.5):\n",
    "\n",
    "    def lr_lambda(current_step):\n",
    "        if current_step < num_warmup_steps:\n",
    "            return float(current_step) / float(max(1, num_warmup_steps))\n",
    "        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n",
    "        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n",
    "\n",
    "    return LambdaLR(optimizer, lr_lambda, -1)\n",
    "\n",
    "def seed_everything(seed):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    \n",
    "# To specify the gpu you want to use, we recommend to start the jupyter server with CUDA_VISIBLE_DEVICES=<gpu_ids>.\n",
    "# threestudio.utils.base.get_device = lambda: torch.device('cuda:0') # hack the cuda device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"An astronaut riding a horse in space\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# stable diffusion \n",
    "config = {\n",
    "    'max_iters': 1000,\n",
    "    'seed': 42,\n",
    "    'scheduler': 'cosine',\n",
    "    'mode': 'latent',\n",
    "    'prompt_processor_type': 'stable-diffusion-prompt-processor',\n",
    "    'prompt_processor': {\n",
    "        'prompt': prompt,\n",
    "    },\n",
    "    'guidance_type': 'stable-diffusion-guidance',\n",
    "    'guidance': {\n",
    "        'half_precision_weights': False,\n",
    "        'guidance_scale': 100.,\n",
    "        'pretrained_model_name_or_path': 'runwayml/stable-diffusion-v1-5',\n",
    "        'grad_clip': None,\n",
    "        'view_dependent_prompting': False,\n",
    "    },\n",
    "    'image': {\n",
    "        'width': 64,\n",
    "        'height': 64,\n",
    "    }\n",
    "}\n",
    "\n",
    "# deepfloyd\n",
    "\n",
    "# config = {\n",
    "#     'max_iters': 1000,\n",
    "#     'seed': 42,\n",
    "#     'scheduler': 'cosine',\n",
    "#     'mode': 'rgb', # deepfloyd does not support latent optimization\n",
    "#     'prompt_processor_type': 'deep-floyd-prompt-processor',\n",
    "#     'prompt_processor': {\n",
    "#         'prompt': prompt,\n",
    "#     },\n",
    "#     'guidance_type': 'deep-floyd-guidance',\n",
    "#     'guidance': {\n",
    "#         'half_precision_weights': True,\n",
    "#         'guidance_scale': 7.,\n",
    "#         'pretrained_model_name_or_path': 'DeepFloyd/IF-I-XL-v1.0',\n",
    "#         'grad_clip': None,\n",
    "#         \"view_dependent_prompting\": False,\n",
    "#     },\n",
    "#     'image': {\n",
    "#         'width': 64,\n",
    "#         'height': 64,\n",
    "#     }\n",
    "# }\n",
    "\n",
    "seed_everything(config['seed'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# just need to rerun the cell when you change guidance or prompt_processor\n",
    "guidance = None\n",
    "prompt_processor = None\n",
    "gc.collect()\n",
    "with torch.no_grad():\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "guidance = threestudio.find(config['guidance_type'])(config['guidance'])\n",
    "prompt_processor = threestudio.find(config['prompt_processor_type'])(config['prompt_processor'])\n",
    "prompt_processor.configure_text_encoder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def figure2image(fig):\n",
    "    buf = io.BytesIO()\n",
    "    fig.savefig(buf)\n",
    "    buf.seek(0)\n",
    "    img = Image.open(buf)\n",
    "    return img\n",
    "\n",
    "def configure_other_guidance_params_manually(guidance, config):\n",
    "    # avoid reloading guidance every time change these params\n",
    "    guidance.cfg.grad_clip = config['guidance']['grad_clip']\n",
    "    guidance.cfg.guidance_scale = config['guidance']['guidance_scale']\n",
    "\n",
    "def run(config):\n",
    "    # clear gpu memory\n",
    "    rgb = None\n",
    "    grad = None\n",
    "    vis_grad = None\n",
    "    vis_grad_norm = None\n",
    "    loss = None\n",
    "    optimizer = None\n",
    "    target = None\n",
    "\n",
    "    gc.collect()\n",
    "    with torch.no_grad():\n",
    "        torch.cuda.empty_cache()\n",
    "    \n",
    "    configure_other_guidance_params_manually(guidance, config)\n",
    "\n",
    "    mode = config['mode']\n",
    "    \n",
    "    w, h = config['image']['width'], config['image']['height']\n",
    "    if mode == 'rgb':\n",
    "        target = nn.Parameter(torch.rand(1, h, w, 3, device=guidance.device))\n",
    "    else:\n",
    "        target = nn.Parameter(torch.randn(1, h, w, 4, device=guidance.device))\n",
    "\n",
    "    optimizer = torch.optim.AdamW([target], lr=1e-1, weight_decay=0)\n",
    "    num_steps = config['max_iters']\n",
    "    scheduler = get_cosine_schedule_with_warmup(optimizer, 100, int(num_steps*1.5)) if config['scheduler'] == 'cosine' else None\n",
    "\n",
    "    rgb = None\n",
    "    plt.axis('off')\n",
    "\n",
    "    img_array = []\n",
    "\n",
    "    try:\n",
    "        for step in tqdm(range(num_steps + 1)):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            batch = {\n",
    "                'elevation': torch.Tensor([0]),\n",
    "                'azimuth': torch.Tensor([0]),\n",
    "                'camera_distances': torch.Tensor([1]),\n",
    "            }\n",
    "\n",
    "            loss = guidance(target, prompt_processor(), **batch, rgb_as_latents=(mode != 'rgb'))\n",
    "            loss['loss_sds'].backward()\n",
    "\n",
    "            grad = target.grad\n",
    "            optimizer.step()\n",
    "            if scheduler is not None:\n",
    "                scheduler.step()\n",
    "            \n",
    "            guidance.update_step(epoch=0, global_step=step)\n",
    "\n",
    "            if step % 5 == 0:\n",
    "                if mode == 'rgb':\n",
    "                    rgb = target\n",
    "                    vis_grad = grad[..., :3]\n",
    "                    vis_grad_norm = grad.norm(dim=-1)\n",
    "                else:\n",
    "                    rgb = guidance.decode_latents(target.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)\n",
    "                    vis_grad = grad\n",
    "                    vis_grad_norm = grad.norm(dim=-1)\n",
    "                \n",
    "                vis_grad_norm = vis_grad_norm / vis_grad_norm.max()\n",
    "                vis_grad = vis_grad / vis_grad.max()\n",
    "                img_rgb = rgb.clamp(0, 1).detach().squeeze(0).cpu().numpy()\n",
    "                img_grad = vis_grad.clamp(0, 1).detach().squeeze(0).cpu().numpy()\n",
    "                img_grad_norm = vis_grad_norm.clamp(0, 1).detach().squeeze(0).cpu().numpy()\n",
    "\n",
    "                fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n",
    "                ax[0].imshow(img_rgb)\n",
    "                ax[1].imshow(img_grad)\n",
    "                ax[2].imshow(img_grad_norm)\n",
    "                ax[0].axis('off')\n",
    "                ax[1].axis('off')\n",
    "                ax[2].axis('off')\n",
    "                clear_output(wait=True)\n",
    "                plt.show()\n",
    "                img_array.append(figure2image(fig))\n",
    "    except KeyboardInterrupt:\n",
    "        pass\n",
    "    finally:\n",
    "        # browse the result\n",
    "        print(\"Optimizing process:\")\n",
    "        images = img_array\n",
    "        \n",
    "        if len(images) > 0:\n",
    "            # Set up the widgets\n",
    "            slider = IntSlider(min=0, max=len(images)-1, step=1, value=1)\n",
    "            output = Output()\n",
    "\n",
    "            def display_image(index):\n",
    "                with output:\n",
    "                    output.clear_output(wait=True)\n",
    "                    display(images[index])\n",
    "\n",
    "            # Link the slider to the display function\n",
    "            interact(display_image, index=slider)\n",
    "\n",
    "            # Display the widgets\n",
    "            # display(slider)\n",
    "            display(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config['mode'] = 'latent'\n",
    "run(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config['mode'] = 'rgb'\n",
    "run(config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
